from typing import Iterable

import numpy as np
from torch import nn

from Policy.Reward.reward import RewardTerminateTruncate


class RewardTermTruncManager(nn.Module):
    """Dispatch reward, termination and truncation queries to a set of RTT functions.

    Computes the reward, termination and done signals for a particular policy;
    these should be a function of the observation, goal and time.
    upper_policy and lower_policies will have this attached to them.

    The wrapped functions are held in an ``nn.ModuleList``. The single-function
    queries (``check_rew``, ``check_term``, ``check_rew_term_trunc``) pick one
    function by index, while the remaining methods fan out to every function.
    """

    def __init__(self, rtt_functions: Iterable[RewardTerminateTruncate]):
        """Register the RTT functions and mirror the first one's ``num_updates``.

        Args:
            rtt_functions: non-empty iterable of modules that manage reward,
                terminate and trunc (see Option.Terminate.General.rtt_base for
                what goes in here).

        Raises:
            ValueError: if ``rtt_functions`` yields no elements.
        """
        super().__init__()
        # Materialise first so one-shot iterables can be checked for emptiness
        # before being handed to ModuleList.
        rtt_functions = list(rtt_functions)
        if not rtt_functions:
            raise ValueError(
                "RewardTermTruncManager requires at least one rtt function"
            )
        self.rtt_functions = nn.ModuleList(rtt_functions)
        # Only the first function's value is consulted.
        # NOTE(review): presumably all functions share it -- confirm.
        self.num_updates = self.rtt_functions[0].num_updates

    def check_rew(self, data, idx: int = 0):
        """Return ``rew(data)`` of the RTT function at position ``idx``."""
        return self.rtt_functions[idx].rew(data)

    def check_term(self, data, idx: int = 0):
        """Return ``term(data)`` of the RTT function at position ``idx``."""
        return self.rtt_functions[idx].term(data)

    def check_rew_term_trunc(self, data, idx: int = 0):
        """Return ``(rewards, terminations, truncations)`` for ``data``.

        The selected function's ``term``/``trunc`` outputs are combined with
        ``data.terminated``/``data.truncated`` using ``|``, so both operands
        must support that operator (presumably boolean arrays -- TODO confirm).
        """
        rtt_func = self.rtt_functions[idx]

        # TODO: may not work for wide lower (multiple low policies)
        rewards = rtt_func.rew(data)
        terms = rtt_func.term(data)
        timeouts = rtt_func.trunc(data)

        # Fold in the signals already carried by the data itself.
        return rewards, terms | data.terminated, timeouts | data.truncated

    def set_updating(self, updating) -> None:
        """Assign ``updating`` to the ``updating`` attribute of every RTT function."""
        for rtt in self.rtt_functions:
            rtt.updating = updating

    def update_schedules(self) -> None:
        """Call ``update_schedules()`` on every RTT function."""
        for rtt in self.rtt_functions:
            rtt.update_schedules()

    def update_state_counts(self, batch) -> None:
        """Call ``update_state_counts(batch)`` on every RTT function."""
        for rtt in self.rtt_functions:
            rtt.update_state_counts(batch)

    def check_reached(self, batch):
        """AND together every function's ``check_reached(batch)`` result.

        Each function must return a ``(graph_reached, goal_reached)`` pair of
        array-likes that broadcast against ``len(batch)`` entries.

        Returns:
            Two float arrays ``(graph_reached, goal_reached)`` holding 1.0
            where every function reported a truthy value and 0.0 elsewhere.
        """
        size = len(batch)
        # Start from all-True so the first AND leaves a function's answer intact.
        graph_reached = np.ones(size, dtype=bool)
        goal_reached = np.ones(size, dtype=bool)
        for rtt in self.rtt_functions:
            graph_part, goal_part = rtt.check_reached(batch)
            graph_reached = graph_reached & np.asarray(graph_part, dtype=bool)
            goal_reached = goal_reached & np.asarray(goal_part, dtype=bool)
        return graph_reached.astype(float), goal_reached.astype(float)

    def update_statistics(self, results) -> None:
        """Call ``update_statistics(results)`` on every RTT function."""
        for rtt in self.rtt_functions:
            rtt.update_statistics(results)

